// Copyright Epic Games, Inc. All Rights Reserved.

#include "zencore/stats.h"

#include <zencore/compactbinarybuilder.h>
#include "zencore/intmath.h"
#include "zencore/thread.h"
#include "zencore/timer.h"

#include <cmath>
#include <gsl/gsl-lite.hpp>

#if ZEN_WITH_TESTS
#	include <zencore/testing.h>
#endif

//
// Derived from https://github.com/dln/medida/blob/master/src/medida/stats/ewma.cc
//

namespace zen::metrics {

// EWMA tick configuration. These are never modified, so they are constexpr
// (constinit only guarantees constant initialization; it left them mutable).
static constexpr int	kTickIntervalInSeconds = 5;
static constexpr double kSecondsPerMinute	   = 60.0;
static constexpr int	kOneMinute			   = 1;
static constexpr int	kFiveMinutes		   = 5;
static constexpr int	kFifteenMinutes		   = 15;

// Per-tick smoothing factors for 1/5/15-minute moving averages:
// alpha = 1 - exp(-tick_seconds / window_seconds).
static const double kM1_ALPHA  = 1.0 - std::exp(-kTickIntervalInSeconds / kSecondsPerMinute / kOneMinute);
static const double kM5_ALPHA  = 1.0 - std::exp(-kTickIntervalInSeconds / kSecondsPerMinute / kFiveMinutes);
static const double kM15_ALPHA = 1.0 - std::exp(-kTickIntervalInSeconds / kSecondsPerMinute / kFifteenMinutes);

// High-frequency timer counts in one EWMA tick interval, and in one second.
static const uint64_t CountPerTick	 = uint64_t(kTickIntervalInSeconds) * GetHifreqTimerFrequencySafe();
static const uint64_t CountPerSecond = GetHifreqTimerFrequencySafe();

//////////////////////////////////////////////////////////////////////////

// Folds one interval's event count into the moving average.
//   Alpha           - smoothing factor for this average
//   Interval        - interval length in high-frequency timer counts
//   Count           - events observed during the interval
//   IsInitialUpdate - seed the average with this interval's rate instead of blending
void
RawEWMA::Tick(double Alpha, uint64_t Interval, uint64_t Count, bool IsInitialUpdate)
{
	// Events per timer count during this interval.
	const double Observed = static_cast<double>(Count) / static_cast<double>(Interval);

	if (IsInitialUpdate)
	{
		m_Rate.store(Observed, std::memory_order_release);
		return;
	}

	// Move the average the fraction Alpha of the way towards the observed rate.
	const double Step = Alpha * (Observed - m_Rate.load());

#if defined(__cpp_lib_atomic_float)
	m_Rate.fetch_add(Step);
#else
	// No atomic floating-point fetch_add available: emulate it with a CAS loop.
	double Expected = m_Rate.load(std::memory_order_acquire);
	while (!m_Rate.compare_exchange_weak(Expected, Expected + Step, std::memory_order_relaxed))
	{
	}
#endif
}

// Current average converted from events per timer count to events per second.
double
RawEWMA::Rate() const
{
	const double PerTimerCount = m_Rate.load(std::memory_order_relaxed);
	return PerTimerCount * CountPerSecond;
}

//////////////////////////////////////////////////////////////////////////

Meter::Meter() : m_StartTick{GetHifreqTimerValue()}, m_LastTick(m_StartTick.load())
{
	// m_LastTick is seeded from m_StartTick. NOTE(review): this relies on
	// m_StartTick being declared before m_LastTick in the class — confirm in header.
}

Meter::~Meter() = default;

// Advances the moving averages by one Tick() per elapsed tick interval.
// Called from Mark() and from the Rate1/Rate5/Rate15 accessors.
void
Meter::TickIfNecessary()
{
	uint64_t	   OldTick = m_LastTick.load();
	const uint64_t NewTick = GetHifreqTimerValue();
	const uint64_t Age	   = NewTick - OldTick;

	// Nothing to do until more than one full interval has passed.
	if (Age > CountPerTick)
	{
		// Ensure only one thread at a time updates the time. This
		// works because our tick interval should be sufficiently
		// long to ensure two threads don't end up inside this block

		if (m_LastTick.compare_exchange_strong(OldTick, NewTick))
		{
			// Bank all elapsed timer counts; m_Remainder carries the leftover
			// balance between calls.
			m_Remainder.fetch_add(Age);

			// Spend CountPerTick per Tick() while the balance is non-negative.
			// A tick fires whenever the balance is >= 0 (not >= CountPerTick), so the
			// loop exits with the balance in [-CountPerTick, 0). It therefore ticks
			// ahead of elapsed time, and every Tick() after the first in one pass sees
			// an empty pending count. NOTE(review): confirm this rounding-up is intended.
			do
			{
				int64_t Remain = m_Remainder.load(std::memory_order_relaxed);

				if (Remain < 0)
				{
					return;
				}

				// Mixed signed/unsigned subtraction: the result wraps to the
				// intended negative value when converted back to int64_t.
				if (m_Remainder.compare_exchange_strong(Remain, Remain - CountPerTick))
				{
					Tick();
				}
			} while (true);
		}
	}
}

void
Meter::Tick()
{
	const uint64_t PendingCount = m_PendingCount.exchange(0);
	const bool	   IsFirstTick	= m_IsFirstTick;

	if (IsFirstTick)
	{
		m_IsFirstTick = false;
	}

	m_RateM1.Tick(kM1_ALPHA, CountPerTick, PendingCount, IsFirstTick);
	m_RateM5.Tick(kM5_ALPHA, CountPerTick, PendingCount, IsFirstTick);
	m_RateM15.Tick(kM15_ALPHA, CountPerTick, PendingCount, IsFirstTick);
}

// One-minute moving average rate (events per second), brought up to date first.
double
Meter::Rate1()
{
	TickIfNecessary();
	const double Current = m_RateM1.Rate();
	return Current;
}

// Five-minute moving average rate (events per second), brought up to date first.
double
Meter::Rate5()
{
	TickIfNecessary();
	const double Current = m_RateM5.Rate();
	return Current;
}

// Fifteen-minute moving average rate (events per second), brought up to date first.
double
Meter::Rate15()
{
	TickIfNecessary();
	const double Current = m_RateM15.Rate();
	return Current;
}

// Average events per second since the meter was constructed.
// Returns 0.0 when nothing has been marked yet. It also returns 0.0 when no timer
// counts have elapsed, because dividing by a zero Elapsed would yield infinity.
double
Meter::MeanRate() const
{
	const uint64_t Count = m_TotalCount.load(std::memory_order_relaxed);

	if (Count == 0)
	{
		return 0.0;
	}

	const uint64_t Elapsed = GetHifreqTimerValue() - m_StartTick;

	if (Elapsed == 0)
	{
		return 0.0;
	}

	return (double(Count) * GetHifreqTimerFrequency()) / Elapsed;
}

// Records Count events. Any overdue ticks run first, so these events are
// attributed to the current interval rather than an already-elapsed one.
void
Meter::Mark(uint64_t Count)
{
	TickIfNecessary();

	m_TotalCount += Count;
	m_PendingCount += Count;
}

//////////////////////////////////////////////////////////////////////////

// Rotates x left by k bits. The count is reduced modulo 64, so k == 0 (or 64)
// returns x unchanged and a negative k rotates right. The masking also avoids
// evaluating `x >> 64`, which is undefined behavior.
uint64_t
rol64(uint64_t x, int k)
{
	const unsigned Shift = static_cast<unsigned>(k) & 63u;
	return (x << Shift) | (x >> ((64u - Shift) & 63u));
}

// Generator state: four 64-bit words. An all-zero state is a fixed point that
// only ever yields 0, so seeds must not be all zero.
struct xoshiro256ss_state
{
	uint64_t s[4];
};

// Advances the xoshiro256** state one step and returns the next 64-bit output.
uint64_t
xoshiro256ss(struct xoshiro256ss_state* state)
{
	uint64_t*	   s	  = state->s;
	uint64_t const result = rol64(s[1] * 5, 7) * 9;
	uint64_t const t	  = s[1] << 17;

	s[2] ^= s[0];
	s[3] ^= s[1];
	s[1] ^= s[2];
	s[0] ^= s[3];

	s[2] ^= t;
	s[3] = rol64(s[3], 45);

	return result;
}

// Function-object wrapper around xoshiro256ss. With result_type, min() and max()
// it satisfies the UniformRandomBitGenerator requirements, so it can drive
// <random> distributions as well as being called directly.
class xoshiro256
{
public:
	using result_type = uint64_t;

	result_type					 operator()() { return xoshiro256ss(&m_State); }
	static constexpr result_type min() { return 0; }
	static constexpr result_type max() { return ~result_type(0); }

private:
	// NOTE(review): fixed seed — every instance, including each thread's
	// ThreadLocalRng, produces the same sequence.
	xoshiro256ss_state m_State{0xf0fefaf9, 0xbeeb5238, 0x48472397, 0x58858558};
};

// Per-thread generator used by UniformSample::Update for reservoir sampling.
thread_local xoshiro256 ThreadLocalRng;

//////////////////////////////////////////////////////////////////////////

UniformSample::UniformSample(uint32_t ReservoirSize) : m_Values(ReservoirSize)
{
	// Reservoir of ReservoirSize slots, filled by Update().
}

UniformSample::~UniformSample() = default;

void
UniformSample::Clear()
{
	for (auto& Value : m_Values)
	{
		Value.store(0);
	}
	m_SampleCounter = 0;
}

// Number of populated slots: the samples seen so far, capped at the reservoir capacity.
uint32_t
UniformSample::Size() const
{
	const auto Seen = m_SampleCounter.load();
	return gsl::narrow_cast<uint32_t>(Min(Seen, m_Values.size()));
}

// Offers one value to the reservoir (Vitter's Algorithm R).
// The first Size values are stored directly. After that, the value with zero-based
// index Count replaces a uniformly chosen slot with probability Size / (Count + 1).
void
UniformSample::Update(int64_t Value)
{
	const uint64_t Count = m_SampleCounter++;
	const uint64_t Size	 = m_Values.size();

	if (Count < Size)
	{
		m_Values[Count] = Value;
		return;
	}

	// Draw from the inclusive range [0, Count]. Using `% Count` would exclude the
	// newest position, which over-weights replacement. It would also divide by
	// zero for a zero-capacity reservoir (Count == Size == 0).
	const uint64_t SampleIndex = ThreadLocalRng() % (Count + 1);

	if (SampleIndex < Size)
	{
		m_Values[SampleIndex].store(Value, std::memory_order_release);
	}
}

// Copies the populated prefix of the reservoir, widened to double, into a
// SampleSnapshot. The SampleSnapshot constructor sorts the copy.
SampleSnapshot
UniformSample::Snapshot() const
{
	const uint32_t Populated = Size();

	std::vector<double> Copy;
	Copy.reserve(Populated);
	for (uint32_t Slot = 0; Slot < Populated; ++Slot)
	{
		Copy.push_back(static_cast<double>(m_Values[Slot].load()));
	}

	return SampleSnapshot(std::move(Copy));
}

//////////////////////////////////////////////////////////////////////////

Histogram::Histogram(int32_t SampleCount) : m_Sample(SampleCount)
{
	// SampleCount sizes the reservoir used for percentile snapshots.
}

Histogram::~Histogram() = default;

// Resets count, sum, max and min to zero, then empties the sample reservoir.
void
Histogram::Clear()
{
	m_Count.store(0);
	m_Sum.store(0);
	m_Max.store(0);
	m_Min.store(0);
	m_Sample.Clear();
}

// Records one value: offers it to the reservoir, widens min/max, and adds it
// to the running sum and count.
void
Histogram::Update(int64_t Value)
{
	m_Sample.Update(Value);

	const bool IsFirstSample = (m_Count.load() == 0);
	if (IsFirstSample)
	{
		// No prior samples: this value is both the maximum and the minimum.
		m_Max.store(Value);
		m_Min.store(Value);
	}
	else
	{
		// Lock-free raise of the maximum; a failed CAS refreshes Seen and retries.
		for (int64_t Seen = m_Max.load(std::memory_order_relaxed); Seen < Value;)
		{
			if (m_Max.compare_exchange_weak(Seen, Value))
			{
				break;
			}
		}

		// Lock-free lowering of the minimum, same pattern.
		for (int64_t Seen = m_Min.load(std::memory_order_relaxed); Seen > Value;)
		{
			if (m_Min.compare_exchange_weak(Seen, Value))
			{
				break;
			}
		}
	}

	m_Sum.fetch_add(Value);
	m_Count.fetch_add(1);
}

// Largest value recorded since construction or the last Clear().
int64_t
Histogram::Max() const
{
	const int64_t Largest = m_Max.load(std::memory_order_relaxed);
	return Largest;
}

// Smallest value recorded since construction or the last Clear().
int64_t
Histogram::Min() const
{
	const int64_t Smallest = m_Min.load(std::memory_order_relaxed);
	return Smallest;
}

double
Histogram::Mean() const
{
	if (m_Count)
	{
		return double(m_Sum.load(std::memory_order_relaxed)) / m_Count;
	}
	else
	{
		return 0.0;
	}
}

// Number of values recorded since construction or the last Clear().
uint64_t
Histogram::Count() const
{
	const uint64_t Recorded = m_Count.load(std::memory_order_relaxed);
	return Recorded;
}

//////////////////////////////////////////////////////////////////////////

SampleSnapshot::SampleSnapshot(std::vector<double>&& Values) : m_Values(std::move(Values))
{
	// Keep the values in ascending order so quantiles can be read by position.
	std::sort(m_Values.begin(), m_Values.end());
}

SampleSnapshot::~SampleSnapshot() = default;

// Linearly interpolated value at Quantile (in [0, 1]) over the sorted samples.
// Uses the 1-based rank Quantile * (N + 1), clamped to the first/last sample.
// Returns 0.0 for an empty snapshot.
double
SampleSnapshot::GetQuantileValue(double Quantile)
{
	ZEN_ASSERT((Quantile >= 0.0) && (Quantile <= 1.0));

	const size_t Count = m_Values.size();
	if (Count == 0)
	{
		return 0.0;
	}

	const double Rank = Quantile * double(Count + 1);

	if (Rank < 1)
	{
		return m_Values.front();
	}
	if (Rank >= double(Count))
	{
		return m_Values.back();
	}

	// Interpolate between the samples at 1-based ranks floor(Rank) and floor(Rank) + 1.
	const size_t Upper	  = static_cast<size_t>(Rank);
	const double Fraction = Rank - std::floor(Rank);
	const double Below	  = m_Values[Upper - 1];
	const double Above	  = m_Values[Upper];
	return Below + Fraction * (Above - Below);
}

const std::vector<double>&
SampleSnapshot::GetValues() const
{
	return m_Values;
}

//////////////////////////////////////////////////////////////////////////

OperationTiming::OperationTiming(int32_t SampleCount) : m_Histogram{SampleCount}
{
	// SampleCount sizes the duration histogram's sample reservoir.
}

OperationTiming::~OperationTiming() = default;

// Records one completed operation: counts it in the meter, then adds its
// duration to the histogram.
void
OperationTiming::Update(int64_t Duration)
{
	constexpr uint64_t OneOperation = 1;
	m_Meter.Mark(OneOperation);
	m_Histogram.Update(Duration);
}

int64_t
OperationTiming::Max() const
{
	return m_Histogram.Max();
}

int64_t
OperationTiming::Min() const
{
	return m_Histogram.Min();
}

double
OperationTiming::Mean() const
{
	return m_Histogram.Mean();
}

uint64_t
OperationTiming::Count() const
{
	return m_Meter.Count();
}

OperationTiming::Scope::Scope(OperationTiming& Outer) : m_Outer(Outer), m_StartTick(GetHifreqTimerValue())
{
}

// Records the elapsed time on destruction unless Stop() or Cancel() already ran.
OperationTiming::Scope::~Scope()
{
	Stop();
}

// Reports the elapsed timer counts to the owner once; later calls do nothing.
void
OperationTiming::Scope::Stop()
{
	// A zero start tick marks the scope as already stopped or cancelled.
	if (m_StartTick == 0)
	{
		return;
	}

	m_Outer.Update(GetHifreqTimerValue() - m_StartTick);
	m_StartTick = 0;
}

// Disarms the scope so nothing is recorded.
void
OperationTiming::Scope::Cancel()
{
	m_StartTick = 0;
}

//////////////////////////////////////////////////////////////////////////

RequestStats::RequestStats(int32_t SampleCount) : m_RequestTimeHistogram{SampleCount}, m_BytesHistogram{SampleCount}
{
	// Both histograms keep reservoirs of SampleCount samples.
}

RequestStats::~RequestStats() = default;

// Records one completed request with its duration and payload size.
void
RequestStats::Update(int64_t Duration, int64_t Bytes)
{
	// Request rate and duration distribution.
	constexpr uint64_t OneRequest = 1;
	m_RequestMeter.Mark(OneRequest);
	m_RequestTimeHistogram.Update(Duration);

	// Byte rate and payload-size distribution.
	m_BytesMeter.Mark(Bytes);
	m_BytesHistogram.Update(Bytes);
}

uint64_t
RequestStats::Count() const
{
	return m_RequestMeter.Count();
}

RequestStats::Scope::Scope(RequestStats& Outer, int64_t Bytes) : m_Outer(Outer), m_StartTick(GetHifreqTimerValue()), m_Bytes(Bytes)
{
}

// Records the request on destruction unless Stop() or Cancel() already ran.
RequestStats::Scope::~Scope()
{
	Stop();
}

// Reports elapsed timer counts and the byte count to the owner once; later calls do nothing.
void
RequestStats::Scope::Stop()
{
	// A zero start tick marks the scope as already stopped or cancelled.
	if (m_StartTick == 0)
	{
		return;
	}

	m_Outer.Update(GetHifreqTimerValue() - m_StartTick, m_Bytes);
	m_StartTick = 0;
}

// Disarms the scope so nothing is recorded.
void
RequestStats::Scope::Cancel()
{
	m_StartTick = 0;
}

//////////////////////////////////////////////////////////////////////////

// Writes a meter's total count and its mean / 1 / 5 / 15-minute rates as fields
// of the object currently open in Cbo.
void
EmitSnapshot(Meter& Stat, CbObjectWriter& Cbo)
{
	Cbo << "count" << Stat.Count() << "rate_mean" << Stat.MeanRate();
	Cbo << "rate_1" << Stat.Rate1();
	Cbo << "rate_5" << Stat.Rate5();
	Cbo << "rate_15" << Stat.Rate15();
}

// Writes an object named Tag containing "requests" and "bytes" sub-objects.
// Each holds meter fields followed by histogram fields.
void
RequestStats::EmitSnapshot(std::string_view Tag, CbObjectWriter& Cbo)
{
	// Qualified calls: the unqualified name would resolve to this member function.
	Cbo.BeginObject(Tag);
	{
		// Request durations are scaled by GetHifreqTimerToSeconds().
		Cbo.BeginObject("requests");
		metrics::EmitSnapshot(m_RequestMeter, Cbo);
		metrics::EmitSnapshot(m_RequestTimeHistogram, Cbo, GetHifreqTimerToSeconds());
		Cbo.EndObject();
	}
	{
		// Byte sizes are emitted unscaled.
		Cbo.BeginObject("bytes");
		metrics::EmitSnapshot(m_BytesMeter, Cbo);
		metrics::EmitSnapshot(m_BytesHistogram, Cbo, 1.0);
		Cbo.EndObject();
	}
	Cbo.EndObject();
}

// Captures request and byte statistics as plain values. Request durations are
// scaled by GetHifreqTimerToSeconds(); byte sizes are left unscaled.
RequestStatsSnapshot
RequestStats::Snapshot()
{
	const double ToSeconds = GetHifreqTimerToSeconds();

	auto Requests = GetSnapshot(m_RequestMeter, m_RequestTimeHistogram, ToSeconds);
	auto Bytes	  = GetSnapshot(m_BytesMeter, m_BytesHistogram, 1.0);

	return RequestStatsSnapshot{.Requests = std::move(Requests), .Bytes = std::move(Bytes)};
}

// Writes an object named Tag with an operation timer's counts and rates,
// followed by its duration statistics converted to seconds.
void
EmitSnapshot(std::string_view Tag, OperationTiming& Stat, CbObjectWriter& Cbo)
{
	Cbo.BeginObject(Tag);

	SampleSnapshot Snap = Stat.Snapshot();

	// Throughput.
	Cbo << "count" << Stat.Count();
	Cbo << "rate_mean" << Stat.MeanRate();
	Cbo << "rate_1" << Stat.Rate1();
	Cbo << "rate_5" << Stat.Rate5();
	Cbo << "rate_15" << Stat.Rate15();

	// Durations are recorded in timer counts; convert to seconds on output.
	const double ToSeconds = GetHifreqTimerToSeconds();

	Cbo << "t_avg" << Stat.Mean() * ToSeconds;
	Cbo << "t_min" << Stat.Min() * ToSeconds;
	Cbo << "t_max" << Stat.Max() * ToSeconds;
	Cbo << "t_p75" << Snap.Get75Percentile() * ToSeconds;
	Cbo << "t_p95" << Snap.Get95Percentile() * ToSeconds;
	Cbo << "t_p99" << Snap.Get99Percentile() * ToSeconds;
	Cbo << "t_p999" << Snap.Get999Percentile() * ToSeconds;

	Cbo.EndObject();
}

// Wraps the histogram fields in an object named Tag.
void
EmitSnapshot(std::string_view Tag, const Histogram& Stat, CbObjectWriter& Cbo, double ConversionFactor)
{
	Cbo.BeginObject(Tag);
	{
		EmitSnapshot(Stat, Cbo, ConversionFactor);
	}
	Cbo.EndObject();
}

// Writes histogram fields into the object currently open in Cbo.
// ConversionFactor scales the value-bearing fields (avg, min, max, percentiles)
// into output units. "count" is a unit-less number of samples, so it is
// emitted unscaled, matching the "count" field of the meter emitters.
void
EmitSnapshot(const Histogram& Stat, CbObjectWriter& Cbo, double ConversionFactor)
{
	SampleSnapshot Snap = Stat.Snapshot();

	Cbo << "count" << Stat.Count() << "avg" << Stat.Mean() * ConversionFactor;
	Cbo << "min" << Stat.Min() * ConversionFactor << "max" << Stat.Max() * ConversionFactor;
	Cbo << "p75" << Snap.Get75Percentile() * ConversionFactor << "p95" << Snap.Get95Percentile() * ConversionFactor << "p99"
		<< Snap.Get99Percentile() * ConversionFactor << "p999" << Snap.Get999Percentile() * ConversionFactor;
}

void
EmitSnapshot(std::string_view Tag, Meter& Stat, CbObjectWriter& Cbo)
{
	Cbo.BeginObject(Tag);

	Cbo << "count" << Stat.Count() << "rate_mean" << Stat.MeanRate();
	Cbo << "rate_1" << Stat.Rate1() << "rate_5" << Stat.Rate5() << "rate_15" << Stat.Rate15();

	Cbo.EndObject();
}

// Writes a captured meter snapshot's count and rates into the object open in Cbo.
void
EmitSnapshot(const MeterSnapshot& Snapshot, CbObjectWriter& Cbo)
{
	Cbo << "count" << Snapshot.Count << "rate_mean" << Snapshot.MeanRate;
	Cbo << "rate_1" << Snapshot.Rate1;
	Cbo << "rate_5" << Snapshot.Rate5;
	Cbo << "rate_15" << Snapshot.Rate15;
}

void
EmitSnapshot(const HistogramSnapshot& Snapshot, CbObjectWriter& Cbo)
{
	Cbo << "t_count" << Snapshot.Count << "t_avg" << Snapshot.Avg;
	Cbo << "t_min" << Snapshot.Min << "t_max" << Snapshot.Max;
	Cbo << "t_p75" << Snapshot.P75 << "t_p95" << Snapshot.P95 << "t_p99" << Snapshot.P999;
}

// Writes an object named Tag holding the meter fields, then the histogram fields.
void
EmitSnapshot(std::string_view Tag, const StatsSnapshot& Snapshot, CbObjectWriter& Cbo)
{
	Cbo.BeginObject(Tag);
	{
		const auto& MeterPart	  = Snapshot.Meter;
		const auto& HistogramPart = Snapshot.Histogram;
		EmitSnapshot(MeterPart, Cbo);
		EmitSnapshot(HistogramPart, Cbo);
	}
	Cbo.EndObject();
}

// Writes an object named Tag with "request" and "bytes" sub-objects. Writes
// nothing at all when no requests were recorded.
// NOTE(review): the live RequestStats::EmitSnapshot uses the key "requests"
// (plural); kept as-is here because consumers may depend on "request".
void
EmitSnapshot(std::string_view Tag, const RequestStatsSnapshot& Snapshot, CbObjectWriter& Cbo)
{
	const bool HasRequests = Snapshot.Requests.Meter.Count != 0;
	if (!HasRequests)
	{
		return;
	}

	Cbo.BeginObject(Tag);
	EmitSnapshot("request", Snapshot.Requests, Cbo);
	EmitSnapshot("bytes", Snapshot.Bytes, Cbo);
	Cbo.EndObject();
}

//////////////////////////////////////////////////////////////////////////

#if ZEN_WITH_TESTS

// Checks Histogram min/max/mean tracking and the quantile lookups of its
// snapshots, growing from an empty histogram to three samples.
TEST_CASE("Core.Stats.Histogram")
{
	Histogram Histo{258};

	// Empty histogram: the snapshot has no values and the median check expects 0.
	SampleSnapshot Snap1 = Histo.Snapshot();
	CHECK_EQ(Snap1.Size(), 0);
	CHECK_EQ(Snap1.GetMedian(), 0);

	// The first sample sets both min and max.
	Histo.Update(1);
	CHECK_EQ(Histo.Min(), 1);
	CHECK_EQ(Histo.Max(), 1);

	SampleSnapshot Snap2 = Histo.Snapshot();
	CHECK_EQ(Snap2.Size(), 1);

	Histo.Update(2);
	CHECK_EQ(Histo.Min(), 1);
	CHECK_EQ(Histo.Max(), 2);

	SampleSnapshot Snap3 = Histo.Snapshot();
	CHECK_EQ(Snap3.Size(), 2);

	// Samples are now {1, 2, -2}, so the mean is (1 + 2 - 2) / 3.
	Histo.Update(-2);
	CHECK_EQ(Histo.Min(), -2);
	CHECK_EQ(Histo.Max(), 2);
	CHECK_EQ(Histo.Mean(), 1 / 3.0);

	// Sorted snapshot {-2, 1, 2}: the median is 1, p99.9 clamps to the largest
	// value, and quantile 0 returns the smallest.
	SampleSnapshot Snap4 = Histo.Snapshot();
	CHECK_EQ(Snap4.Size(), 3);
	CHECK_EQ(Snap4.GetMedian(), 1);
	CHECK_EQ(Snap4.Get999Percentile(), 2);
	CHECK_EQ(Snap4.GetQuantileValue(0), -2);
}

// Feeds the values 1..100 a hundred times (10,000 updates) into a 100-slot
// reservoir and checks that the retained values average near the population mean.
TEST_CASE("Core.Stats.UniformSample")
{
	UniformSample Sample1{100};

	for (int i = 0; i < 100; ++i)
	{
		for (int j = 1; j <= 100; ++j)
		{
			Sample1.Update(j);
		}
	}

	int64_t Sum	  = 0;
	int64_t Count = 0;

	Sample1.IterateValues([&](int64_t Value) {
		++Count;
		Sum += Value;
	});

	double Average = double(Sum) / Count;

	// The population mean of 1..100 is 50.5; the tolerance allows for sampling noise.
	CHECK(fabs(Average - 50) < 10);	 // What's the right test here? The result could vary massively and still be technically correct
}

// Checks RawEWMA convergence. Every Tick() passes Interval == CountPerSecond,
// and Rate() multiplies by CountPerSecond, so Rate() reports the per-tick Count.
// It should approach each new steady input after enough ticks.
TEST_CASE("Core.Stats.EWMA")
{
	SUBCASE("Simple_1")
	{
		// Seed with 5, then step to 10 and then 20, checking convergence after each run of ticks.
		RawEWMA Ewma1;
		Ewma1.Tick(kM1_ALPHA, CountPerSecond, 5, true);

		CHECK(fabs(Ewma1.Rate() - 5) < 0.1);

		for (int i = 0; i < 60; ++i)
		{
			Ewma1.Tick(kM1_ALPHA, CountPerSecond, 10, false);
		}

		CHECK(fabs(Ewma1.Rate() - 10) < 0.1);

		for (int i = 0; i < 60; ++i)
		{
			Ewma1.Tick(kM1_ALPHA, CountPerSecond, 20, false);
		}

		CHECK(fabs(Ewma1.Rate() - 20) < 0.1);
	}

	SUBCASE("Simple_10")
	{
		// Same inputs for the 1/5/15 smoothing factors. Smaller alphas converge
		// more slowly, so each average is only checked once enough ticks have run.
		RawEWMA Ewma1;
		RawEWMA Ewma5;
		RawEWMA Ewma15;
		Ewma1.Tick(kM1_ALPHA, CountPerSecond, 5, true);
		Ewma5.Tick(kM5_ALPHA, CountPerSecond, 5, true);
		Ewma15.Tick(kM15_ALPHA, CountPerSecond, 5, true);

		CHECK(fabs(Ewma1.Rate() - 5) < 0.1);
		CHECK(fabs(Ewma5.Rate() - 5) < 0.1);
		CHECK(fabs(Ewma15.Rate() - 5) < 0.1);

		auto Tick1	= [&Ewma1](auto Value) { Ewma1.Tick(kM1_ALPHA, CountPerSecond, Value, false); };
		auto Tick5	= [&Ewma5](auto Value) { Ewma5.Tick(kM5_ALPHA, CountPerSecond, Value, false); };
		auto Tick15 = [&Ewma15](auto Value) { Ewma15.Tick(kM15_ALPHA, CountPerSecond, Value, false); };

		for (int i = 0; i < 60; ++i)
		{
			Tick1(10);
			Tick5(10);
			Tick15(10);
		}

		CHECK(fabs(Ewma1.Rate() - 10) < 0.1);

		for (int i = 0; i < 5 * 60; ++i)
		{
			Tick1(20);
			Tick5(20);
			Tick15(20);
		}

		CHECK(fabs(Ewma1.Rate() - 20) < 0.1);
		CHECK(fabs(Ewma5.Rate() - 20) < 0.1);

		for (int i = 0; i < 16 * 60; ++i)
		{
			Tick1(100);
			Tick5(100);
			Tick15(100);
		}

		CHECK(fabs(Ewma1.Rate() - 100) < 0.1);
		CHECK(fabs(Ewma5.Rate() - 100) < 0.1);
		CHECK(fabs(Ewma15.Rate() - 100) < 0.5);
	}
}

#	if 0  // This is not really a unit test, but mildly useful to exercise some code
// Marks one event per second for nine seconds, then reads the mean rate
// (excluded from the build; it sleeps and asserts nothing).
TEST_CASE("Meter")
{
	Meter Meter1;
	Meter1.Mark(1);
	Sleep(1000);
	Meter1.Mark(1);
	Sleep(1000);
	Meter1.Mark(1);
	Sleep(1000);
	Meter1.Mark(1);
	Sleep(1000);
	Meter1.Mark(1);
	Sleep(1000);
	Meter1.Mark(1);
	Sleep(1000);
	Meter1.Mark(1);
	Sleep(1000);
	Meter1.Mark(1);
	Sleep(1000);
	Meter1.Mark(1);
	Sleep(1000);
	[[maybe_unused]] double Rate = Meter1.MeanRate();
}
#	endif
}

namespace zen {

// Intentionally empty. NOTE(review): presumably called from elsewhere so the
// linker keeps this translation unit and its TEST_CASEs — confirm at the call site.
void
stats_forcelink()
{
}

#endif

}  // namespace zen::metrics
